package io.kiki.sba.registry.task;


import io.kiki.sba.registry.metrics.TaskMetrics;
import io.kiki.sba.registry.util.ConcurrentUtils;
import io.prometheus.client.Counter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.concurrent.BlockingQueue;
import java.util.concurrent.TimeUnit;

/**
 * Executor that dispatches each task to one of a fixed set of worker threads, chosen from the
 * hash code of the task's key. Tasks whose keys map to the same index are therefore handled by
 * the same worker, which runs its tasks one at a time.
 *
 * <p>Note from the original author: thread unsafe, could not use concurrently.
 * NOTE(review): whether concurrent {@link #execute(Object, Runnable)} calls are safe depends on
 * {@code BlockingQueues#offer}, which is not visible here; confirm before relying on it.
 */
public class KeyedThreadPoolExecutor {
    private static final Logger logger = LoggerFactory.getLogger(KeyedThreadPoolExecutor.class);
    /** Name used for thread names, log lines, the metric name and the rejection message. */
    protected final String executorName;
    /** Buffer size handed to {@link #createWorkers(int, int)} and reported on rejection. */
    protected final int coreBufferSize;
    /** Number of worker threads started by the constructor. */
    protected final int coreSize;
    private final AbstractWorker[] workers;
    private final Counter taskCounter;

    /**
     * Creates the workers, starts one thread per worker and registers this executor with
     * {@link TaskMetrics}.
     *
     * @param executorName   name of the executor; {@code '-'} is replaced by {@code '_'} when
     *                       building the metric name
     * @param coreSize       number of workers/threads, must be positive
     * @param coreBufferSize buffer size passed through to {@link #createWorkers(int, int)}
     * @throws IllegalArgumentException if {@code coreSize} is not positive
     */
    public KeyedThreadPoolExecutor(String executorName, int coreSize, int coreBufferSize) {
        // Reject an unusable size before any side effect: otherwise the counter below would
        // already be registered when the failure finally surfaces (array allocation for a
        // negative size, or "% 0" on the first execute() for a zero size).
        if (coreSize <= 0) {
            throw new IllegalArgumentException(String.format("%s coreSize must be positive, got %d", executorName, coreSize));
        }
        this.executorName = executorName;
        this.coreBufferSize = coreBufferSize;
        this.coreSize = coreSize;
        // Must be assigned before createWorkers(): the AbstractWorker constructor reads it.
        this.taskCounter = Counter.build().namespace("keyedExecutor").help("metrics for keyed executor").name(executorName.replace('-', '_') + "_task_total").labelNames("idx", "type").register();

        this.workers = createWorkers(coreSize, coreBufferSize);
        for (int i = 0; i < coreSize; i++) {
            ConcurrentUtils.createDaemonThread(executorName + "_" + i, workers[i]).start();
        }
        TaskMetrics.getInstance().registerKeyThreadExecutor("KeyedExecutor-" + executorName, this);
    }

    /**
     * Builds the worker array. Called from the constructor, so an override must not depend on
     * subclass fields that are initialized later.
     *
     * @param coreSize       number of workers to create
     * @param coreBufferSize buffer size forwarded to {@code BlockingQueues}
     *                       (NOTE(review): exact meaning is defined by {@code BlockingQueues})
     * @return one worker per index in {@code [0, coreSize)}
     */
    protected AbstractWorker[] createWorkers(int coreSize, int coreBufferSize) {
        BlockingQueues<KeyedTask> queues = new BlockingQueues<>(coreSize, coreBufferSize, false);
        AbstractWorker[] created = new AbstractWorker[coreSize];
        for (int i = 0; i < coreSize; i++) {
            created[i] = new WorkerImpl(i, queues);
        }
        return created;
    }

    /**
     * @return the sum of {@code size()} over all workers
     */
    public int getQueueSize() {
        int size = 0;
        for (AbstractWorker<?> w : workers) {
            size += w.size();
        }
        return size;
    }

    /**
     * @return the number of workers that are currently inside a task's {@code run()}
     */
    public int getActiveCount() {
        int count = 0;
        for (AbstractWorker<?> w : workers) {
            if (w.isActive()) {
                count++;
            }
        }
        return count;
    }

    /**
     * @return the sum of the per-worker "commit" counters, i.e. tasks accepted by
     *         {@link #execute(Object, Runnable)}
     */
    public long getTaskCount() {
        long count = 0;
        for (AbstractWorker<?> w : workers) {
            count += w.workerCommitCounter.get();
        }
        return count;
    }

    /**
     * @return the sum of the per-worker "exec" counters, i.e. tasks whose {@code run()} returned
     *         without throwing
     */
    public long getCompletedTaskCount() {
        long count = 0;
        for (AbstractWorker<?> w : workers) {
            count += w.workerExecCounter.get();
        }
        return count;
    }

    /**
     * @return the number of workers requested at construction time
     */
    public int getCoreSize() {
        return coreSize;
    }

    /**
     * Wraps {@code runnable} in a {@link KeyedTask} and hands it to the worker selected by
     * {@code key}.
     *
     * @param key      routing key; its {@code hashCode()} selects the worker, so it must not be
     *                 {@code null}
     * @param runnable the work to run
     * @param <T>      concrete runnable type, preserved in the returned task
     * @return the task that was handed to the worker
     * @throws FastRejectedExecutionException if the selected worker refuses the task
     */
    public <T extends Runnable> KeyedTask<T> execute(Object key, T runnable) {
        KeyedTask<T> task = new KeyedTask<>(key, runnable);
        AbstractWorker<?> w = workerOf(key);
        // offer() returning false means the worker did not accept the task; report it to the
        // caller instead of dropping the task silently.
        if (!w.offer(task)) {
            throw new FastRejectedExecutionException(String.format("%s_%d full, max=%d, now=%d", executorName, w.idx, coreBufferSize, w.size()));
        }
        w.workerCommitCounter.inc();
        return task;
    }

    /**
     * Maps a key to a worker: the sign bit of the hash code is cleared so the modulo result is
     * always a valid, non-negative index.
     */
    private AbstractWorker<?> workerOf(Object key) {
        int n = (key.hashCode() & 0x7fffffff) % workers.length;
        return workers[n];
    }

    /**
     * Minimal queue-backed worker contract used by the executor.
     */
    protected interface Worker extends Runnable {
        /** @return the number of tasks currently held by this worker */
        int size();

        /**
         * Fetches the next task to run.
         *
         * @return the next task, or {@code null} if none is available
         * @throws InterruptedException if interrupted while waiting
         */
        KeyedTask poll() throws InterruptedException;

        /**
         * Hands a task to this worker.
         *
         * @return {@code true} if the task was accepted
         */
        boolean offer(KeyedTask task);
    }

    /**
     * Default worker: reads from the queue at its own index and offers through the shared
     * {@code BlockingQueues}.
     */
    private final class WorkerImpl extends AbstractWorker {
        final BlockingQueues<KeyedTask> queues;
        final BlockingQueue<KeyedTask> queue;

        WorkerImpl(int idx, BlockingQueues<KeyedTask> queues) {
            super(idx);
            this.queues = queues;
            this.queue = queues.getQueue(idx);
        }

        @Override
        public int size() {
            return queue.size();
        }

        /** Waits up to 180 seconds for a task; returns {@code null} when the wait times out. */
        @Override
        public KeyedTask poll() throws InterruptedException {
            return queue.poll(180, TimeUnit.SECONDS);
        }

        /** Offers through {@code BlockingQueues} rather than directly on the queue. */
        @Override
        public boolean offer(KeyedTask task) {
            return queues.offer(idx, task);
        }
    }

    /**
     * Base worker holding the run loop and the per-worker counters. The type parameter
     * {@code T} is not used inside this class; it is kept so existing subclasses keep compiling.
     */
    protected abstract class AbstractWorker<T> implements Worker {
        final int idx;
        final Counter.Child workerExecCounter;
        final Counter.Child workerCommitCounter;
        /** {@code true} only while a task's {@code run()} is in progress on this worker. */
        volatile boolean running;

        protected AbstractWorker(int idx) {
            this.idx = idx;
            this.workerExecCounter = taskCounter.labels(String.valueOf(idx), "exec");
            this.workerCommitCounter = taskCounter.labels(String.valueOf(idx), "commit");
        }

        /**
         * Polls and runs tasks forever. Any {@link Throwable}, including an
         * {@link InterruptedException} from {@link #poll()}, is logged and the loop continues,
         * so a failing task never terminates the worker.
         */
        @Override
        public void run() {
            for (; ; ) {
                try {
                    final KeyedTask task = poll();
                    if (task == null) {
                        logger.info("{}_{} idle", executorName, idx);
                        continue;
                    }
                    running = true;
                    task.run();
                    // Only reached when run() returned normally.
                    workerExecCounter.inc();
                } catch (Throwable e) {
                    logger.error("{}_{} run task error", executorName, idx, e);
                } finally {
                    // Also executed on the idle "continue" path, where it is a no-op.
                    running = false;
                }
            }
        }

        /** @return whether this worker is currently running a task */
        protected boolean isActive() {
            return running;
        }
    }
}
